{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Definition of KL divergence"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\\begin{align*}\n",
    "KL(P \\parallel Q)\n",
    "&= KL(P(X) \\parallel Q(X)) \\\\\n",
    "&= \\sum_{x} P(x) \\log \\frac{P(x)}{Q(x)}\n",
    "\\end{align*}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Convention: $0\\log 0 = 0$"
   ]
  },
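  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check of the definition, here is a small worked example. The distributions are illustrative assumptions (not part of the original problem), and $\\log$ is taken to be the natural logarithm. Let $x \\in \\{0, 1\\}$ with $P = (\\tfrac12, \\tfrac12)$ and $Q = (\\tfrac14, \\tfrac34)$:\n",
    "\n",
    "\\begin{align*}\n",
    "KL(P \\parallel Q)\n",
    "&= \\tfrac12 \\log \\frac{1/2}{1/4} + \\tfrac12 \\log \\frac{1/2}{3/4} \\\\\n",
    "&= \\tfrac12 \\log 2 + \\tfrac12 \\log \\tfrac23 \\\\\n",
    "&= \\tfrac12 \\log \\tfrac43 \\approx 0.1438\n",
    "\\end{align*}\n",
    "\n",
    "The convention matters when $P$ assigns zero mass to some outcome. With the hypothetical choice $P = (1, 0)$ and $Q = (\\tfrac12, \\tfrac12)$, the second term is $0 \\log \\frac{0}{1/2} = 0$, so $KL(P \\parallel Q) = 1 \\cdot \\log \\frac{1}{1/2} = \\log 2 \\approx 0.6931$."
   ]
  },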
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## (a) Nonnegativity"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Prove\n",
    "\n",
    "$$ \\forall P, Q \\;\\;\\; KL(P\\parallel Q) \\ge 0$$"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\\begin{align*}\n",
    "KL(P \\parallel Q) \n",
    "&= -\\sum_{x} P(x) \\log \\frac{Q(x)}{P(x)} \\\\\n",
    "&= E_{x \\sim P}\\left[-\\log \\frac{Q(x)}{P(x)}\\right] \\\\\n",
    "&\\ge -\\log E_{x \\sim P}\\left[\\frac{Q(x)}{P(x)}\\right] \\\\\n",
    "&= - \\log \\sum_x P(x) \\frac{Q(x)}{P(x)} \\\\\n",
    "&= - \\log \\sum_x Q(x) \\\\\n",
    "&= - \\log 1 \\\\\n",
    "&= 0\n",
    "\\end{align*}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The inequality in the third line follows from Jensen's inequality, since $-\\log$ is a convex function."
   ]
  },
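  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see the Jensen step at work, consider a small numerical check. The distributions are illustrative assumptions, not from the original problem: $P = (\\tfrac12, \\tfrac12)$ and $Q = (\\tfrac14, \\tfrac34)$, with natural logarithms, so that $\\frac{Q(x)}{P(x)}$ takes the values $\\tfrac12$ and $\\tfrac32$.\n",
    "\n",
    "\\begin{align*}\n",
    "E_{x \\sim P}\\left[-\\log \\frac{Q(x)}{P(x)}\\right] &= -\\tfrac12 \\log \\tfrac12 - \\tfrac12 \\log \\tfrac32 = \\tfrac12 \\log \\tfrac43 \\approx 0.1438 \\\\\n",
    "-\\log E_{x \\sim P}\\left[\\frac{Q(x)}{P(x)}\\right] &= -\\log \\left( \\tfrac12 \\cdot \\tfrac12 + \\tfrac12 \\cdot \\tfrac32 \\right) = -\\log 1 = 0\n",
    "\\end{align*}\n",
    "\n",
    "The first quantity is strictly larger than the second here, which is consistent with $P \\ne Q$."
   ]
  },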
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Prove\n",
    "\n",
    "$$ KL(P \\parallel Q) = 0 \\text{ if and only if } P = Q$$"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Clearly, if $P = Q$, then every term is $P(x) \\log 1 = 0$, so $KL(P \\parallel Q) = \\sum_x 0 = 0$."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Conversely, suppose $KL(P\\parallel Q) = 0$. Then Jensen's inequality in the proof of nonnegativity holds with equality. Because $-\\log x$ is strictly convex, equality requires $\\frac{Q}{P} = E[\\frac{Q}{P}] = 1$ with probability 1 (see the lecture notes on Jensen's inequality), so $P = Q$."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## (b) Chain rule for KL divergence"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Definition of the KL divergence between two conditional distributions:\n",
    "\n",
    "$$\n",
    "KL(P(X|Y) \\parallel Q(X|Y)) = \\sum_y P(y) \\bigg (\\sum_x P(x|y) \\log \\frac{P(x|y)}{Q(x|y)} \\bigg)\n",
    "$$"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Prove \n",
    "\n",
    "$$\n",
    "KL(P(X, Y) \\parallel Q(X, Y)) = KL(P(X) \\parallel Q(X)) + KL(P(Y|X) \\parallel Q(Y|X))\n",
    "$$"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Expand each term on the RHS:"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\\begin{align*}\n",
    "KL(P(X) \\parallel Q(X)) &= \\sum_{x} P(x) \\log \\frac{P(x)}{Q(x)}\n",
    "\\end{align*}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\\begin{align*}\n",
    "KL(P(Y|X) \\parallel Q(Y|X)) &= \\sum_{x} P(x) \\bigg (\\sum_y P(y|x) \\log \\frac{P(y|x)}{Q(y|x)} \\bigg )\n",
    "\\end{align*}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Add the two terms and simplify:"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\\begin{align*}\n",
    "KL(P(X) \\parallel Q(X)) + KL(P(Y|X) \\parallel Q(Y|X)) \n",
    "&= \\sum_x P(x) \\bigg( \\log \\frac{P(x)}{Q(x)} + \\sum_y P(y|x) \\log \\frac{P(y|x)}{Q(y|x)} \\bigg) \\\\\n",
    "&= \\sum_x P(x) \\bigg( \\sum_y P(y|x) \\bigg( \\log \\frac{P(x)}{Q(x)} + \\log \\frac{P(y|x)}{Q(y|x)} \\bigg) \\bigg) \\\\\n",
    "&= \\sum_x P(x) \\bigg( \\sum_y P(y|x) \\bigg( \\log \\frac{P(x) P(y|x)}{Q(x) Q(y|x)} \\bigg) \\bigg) \\\\\n",
    "&= \\sum_x P(x) \\bigg( \\sum_y P(y|x) \\bigg( \\log \\frac{P(x, y)}{Q(x, y)} \\bigg) \\bigg) \\\\\n",
    "&= \\sum_x \\sum_y P(x, y) \\log \\frac{P(x, y)}{Q(x, y)} \\\\\n",
    "&= KL(P(X, Y) \\parallel Q(X, Y))\n",
    "\\end{align*}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The second equality moves $\\log \\frac{P(x)}{Q(x)}$ inside the sum over $y$, which is valid because $\\sum_y P(y|x) = 1$. The fourth equality uses $P(x, y) = P(x) P(y|x)$, and likewise for $Q$."
   ]
  },
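  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A numerical check of the chain rule may help, using hypothetical joint distributions over binary $X$ and $Y$ chosen only for illustration. Let $P(x, y) = \\tfrac14$ for all four outcomes, and let $Q(0,0) = \\tfrac12$, $Q(0,1) = \\tfrac14$, $Q(1,0) = Q(1,1) = \\tfrac18$. Then $P(X) = (\\tfrac12, \\tfrac12)$, $Q(X) = (\\tfrac34, \\tfrac14)$, $P(y|x) = \\tfrac12$ for every $x, y$, $Q(Y|X=0) = (\\tfrac23, \\tfrac13)$ and $Q(Y|X=1) = (\\tfrac12, \\tfrac12)$.\n",
    "\n",
    "\\begin{align*}\n",
    "KL(P(X, Y) \\parallel Q(X, Y)) &= \\tfrac14 \\left( \\log \\tfrac12 + \\log 1 + \\log 2 + \\log 2 \\right) = \\tfrac14 \\log 2 \\\\\n",
    "KL(P(X) \\parallel Q(X)) &= \\tfrac12 \\log \\tfrac23 + \\tfrac12 \\log 2 = \\tfrac12 \\log \\tfrac43 \\\\\n",
    "KL(P(Y|X) \\parallel Q(Y|X)) &= \\tfrac12 \\left( \\tfrac12 \\log \\tfrac34 + \\tfrac12 \\log \\tfrac32 \\right) + \\tfrac12 \\cdot 0 = \\tfrac14 \\log \\tfrac98\n",
    "\\end{align*}\n",
    "\n",
    "The two terms on the RHS sum to $\\tfrac14 \\left( \\log \\tfrac{16}{9} + \\log \\tfrac98 \\right) = \\tfrac14 \\log 2$, matching the joint divergence."
   ]
  },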
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## (c) KL and maximum likelihood"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Given \n",
    "\n",
    "$$ \\hat P(x) = \\frac{1}{m} \\sum_{i=1}^{m} 1 \\{x^{(i)} = x\\}$$"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Prove \n",
    "\n",
    "$$ \\arg\\!\\min_{\\theta} KL(\\hat P \\parallel P_{\\theta}) = \\arg\\!\\max_{\\theta} \\sum_{i=1}^{m} \\log P_{\\theta} (x^{(i)})$$"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\\begin{align*}\n",
    "KL(\\hat P \\parallel P_{\\theta}) \n",
    "&= \\sum_x \\hat P(x) \\log \\frac{\\hat P(x)}{P_{\\theta}(x)} \\\\\n",
    "&= - \\sum_x \\hat P(x) \\log \\frac{P_{\\theta}(x)}{\\hat P(x)} \\\\\n",
    "&= - \\sum_x \\frac{1}{m} \\sum_{i=1}^{m} 1 \\{x^{(i)} = x\\} \\log \\frac{P_{\\theta}(x)}{\\frac{1}{m} \\sum_{i=1}^{m} 1 \\{x^{(i)} = x\\}} \\\\\n",
    "&= - \\frac{1}{m} \\sum_{i=1}^{m} \\log \\frac{P_{\\theta}(x^{(i)})}{\\hat P(x^{(i)})} \\\\\n",
    "&= - \\frac{1}{m} \\sum_{i=1}^{m} \\log P_{\\theta}(x^{(i)}) + \\frac{1}{m} \\sum_{i=1}^{m} \\log \\hat P(x^{(i)}) \\\\\n",
    "&= - \\frac{1}{m} \\text{log-likelihood} + \\text{const}\n",
    "\\end{align*}"
   ]
  },
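  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "* The 3rd equality replaces $\\hat P(x)$ with its definition.\n",
    "* The 4th equality swaps the order of summation and uses $\\sum_x 1 \\{x^{(i)} = x\\} f(x) = f(x^{(i)})$ for any function $f$."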
   ]
  },
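  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A small worked instance may make the reduction concrete. The sample and the Bernoulli model below are assumptions for illustration only. Take $m = 4$ observations $(1, 1, 1, 0)$, so $\\hat P(1) = \\tfrac34$ and $\\hat P(0) = \\tfrac14$, and let $P_{\\theta}(1) = \\theta$, $P_{\\theta}(0) = 1 - \\theta$.\n",
    "\n",
    "\\begin{align*}\n",
    "KL(\\hat P \\parallel P_{\\theta})\n",
    "&= \\tfrac34 \\log \\frac{3/4}{\\theta} + \\tfrac14 \\log \\frac{1/4}{1 - \\theta} \\\\\n",
    "&= -\\tfrac14 \\big( 3 \\log \\theta + \\log (1 - \\theta) \\big) + \\big( \\tfrac34 \\log \\tfrac34 + \\tfrac14 \\log \\tfrac14 \\big)\n",
    "\\end{align*}\n",
    "\n",
    "Here $3 \\log \\theta + \\log(1 - \\theta)$ is the log-likelihood of the sample, and the last bracket does not depend on $\\theta$. Setting the derivative $\\frac{3}{\\theta} - \\frac{1}{1 - \\theta}$ to zero gives $\\theta = \\tfrac34$, which maximizes the log-likelihood and makes $KL(\\hat P \\parallel P_{\\theta}) = 0$."
   ]
  },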
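  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Since $\\frac{1}{m}$ is a positive constant and $\\hat P$ does not depend on $\\theta$, minimizing $KL(\\hat P \\parallel P_{\\theta})$ over $\\theta$ is equivalent to maximizing the log-likelihood $\\sum_{i=1}^{m} \\log P_{\\theta}(x^{(i)})$."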
   ]
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python [default]",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
